2024 iThome Ironman Challenge

DAY 29
Software Development

LSTM Combined with YOLOv8 for Multi-Zebrafish Behavior Analysis — Part 29

Day 29: LSTM combined with YOLOv8 for multi-zebrafish behavior analysis

Today is day twenty-nine, and here is the final version of the LSTM-plus-YOLOv8 program for analyzing the behavior of multiple zebrafish. I consider it the most efficient code I have written in this series. The code is as follows:

import sqlite3
import logging
import threading

import cv2
import numpy as np
from ultralytics import YOLO
from sort import Sort
from tensorflow.keras.models import load_model
from concurrent.futures import ThreadPoolExecutor, as_completed

# Configure logging
logging.basicConfig(filename='zebrafish_analysis.log', level=logging.INFO, 
                    format='%(asctime)s:%(levelname)s:%(message)s')

# YOLOv8 detection module
class YOLOv8Detector:
    def __init__(self):
        # YOLOv8 is loaded through the ultralytics package
        # (torch.hub only hosts YOLOv5 models)
        self.model = YOLO('yolov8n.pt')
    
    def detect(self, frame):
        results = self.model(frame)
        boxes = results[0].boxes
        # SORT expects an (N, 5) array of [x1, y1, x2, y2, confidence]
        return np.hstack([boxes.xyxy.cpu().numpy(),
                          boxes.conf.cpu().numpy().reshape(-1, 1)])

# Tracking module
class FishTracker:
    def __init__(self):
        self.tracker = Sort()

    def track(self, detections):
        # detections is already a NumPy array, so no .cpu() call is needed
        tracked_objects = self.tracker.update(detections)
        return tracked_objects

# Feature-extraction module
class FeatureExtractor:
    def extract(self, tracks):
        features = []
        for track in tracks:
            # SORT returns rows of [x1, y1, x2, y2, track_id]
            x_min, y_min, x_max, y_max, track_id = track
            center_x = (x_min + x_max) / 2
            center_y = (y_min + y_max) / 2
            width = x_max - x_min
            height = y_max - y_min
            
            # Box diagonal as a simple size feature; true speed would need
            # center positions from consecutive frames
            diagonal = np.sqrt(width ** 2 + height ** 2)
            features.append([track_id, center_x, center_y, width, height, diagonal])
        
        return np.array(features)

# Behavior-classification module
class BehaviorClassifier:
    def __init__(self, model_path):
        self.lstm_model = load_model(model_path)
    
    def classify(self, features):
        # LSTMs expect 3-D input (batch, timesteps, features); each frame
        # is treated here as a sequence of length 1
        predictions = self.lstm_model.predict(features[:, np.newaxis, :], verbose=0)
        return np.argmax(predictions, axis=1)

# Result-saving module
class ResultSaver:
    def __init__(self, db_path='zebrafish_results.db'):
        # check_same_thread=False lets the worker threads share this
        # connection; a lock serializes the writes
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        self.cursor = self.conn.cursor()
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS results 
                               (track_id INTEGER, behavior TEXT, timestamp TEXT)''')
        self.conn.commit()

    def save(self, track_id, behavior, timestamp):
        with self.lock:
            self.cursor.execute("INSERT INTO results (track_id, behavior, timestamp) VALUES (?, ?, ?)", 
                                (int(track_id), str(behavior), str(timestamp)))
            self.conn.commit()

    def close(self):
        self.conn.close()

# Main analysis module
class ZebrafishAnalyzer:
    def __init__(self):
        self.detector = YOLOv8Detector()
        self.tracker = FishTracker()
        self.extractor = FeatureExtractor()
        self.classifier = BehaviorClassifier('/path_to_your_model/lstm_model.h5')
        self.saver = ResultSaver()

    def analyze_frame(self, frame, timestamp):
        detections = self.detector.detect(frame)
        tracks = self.tracker.track(detections)
        features = self.extractor.extract(tracks)

        if len(features) > 0:
            behaviors = self.classifier.classify(features)
            for i, track in enumerate(tracks):
                # same [x1, y1, x2, y2, track_id] layout as in FeatureExtractor
                x_min, y_min, x_max, y_max, track_id = track
                behavior = behaviors[i]
                self.saver.save(track_id, behavior, timestamp)
                logging.info(f'Track ID: {track_id}, Behavior: {behavior}, Timestamp: {timestamp}')
                cv2.putText(frame, f'Behavior: {behavior}', (int(x_min), int(y_min) - 10), 
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                cv2.rectangle(frame, (int(x_min), int(y_min)), 
                              (int(x_max), int(y_max)), (255, 0, 0), 2)

        return frame

    def close(self):
        self.saver.close()

# Parallel-processing design
def process_video(video_path):
    analyzer = ZebrafishAnalyzer()
    cap = cv2.VideoCapture(video_path)

    # Note: SORT is stateful and assumes frames arrive in order, so
    # concurrent analyze_frame calls can degrade tracking quality
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            
            timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
            futures.append(executor.submit(analyzer.analyze_frame, frame, timestamp))
            
            if len(futures) > 10:
                for future in as_completed(futures):
                    frame = future.result()
                    cv2.imshow('Zebrafish Behavior Analysis', frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        cap.release()
                        cv2.destroyAllWindows()
                        analyzer.close()
                        return
                # clear the drained batch so finished futures are not re-shown
                futures.clear()
        
        for future in as_completed(futures):
            frame = future.result()
            cv2.imshow('Zebrafish Behavior Analysis', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    cv2.destroyAllWindows()
    analyzer.close()

# Run the video processing
if __name__ == '__main__':
    process_video('zebrafish_video.mp4')

1. Modular Design

The code is split into modules that each handle a distinct task: detection, tracking, feature extraction, behavior classification, result saving, and a main analysis module. This design improves maintainability and lets each part be developed and tested independently.

1.1 YOLOv8 Detection Module (YOLOv8Detector)

This module uses a YOLOv8 model to detect the positions of the zebrafish in each video frame.

class YOLOv8Detector:
    def __init__(self):
        self.model = YOLO('yolov8n.pt')
    
    def detect(self, frame):
        results = self.model(frame)
        boxes = results[0].boxes
        return np.hstack([boxes.xyxy.cpu().numpy(),
                          boxes.conf.cpu().numpy().reshape(-1, 1)])
  • __init__: loads the YOLOv8 model via the ultralytics package at initialization.
  • detect: given a video frame, runs the model and returns the bounding boxes (with confidence scores) of all detected zebrafish.
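As a quick sanity check, the detector can be run on a single image. A minimal sketch, assuming a local test image named test_frame.jpg (a hypothetical file name, not part of the original code):

import cv2

# 'test_frame.jpg' is a hypothetical test image used only for illustration
detector = YOLOv8Detector()
frame = cv2.imread('test_frame.jpg')
detections = detector.detect(frame)
print(detections.shape)  # (N, 5): x1, y1, x2, y2, confidence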

1.2 Tracking Module (FishTracker)

This module tracks the zebrafish detected in each frame using the SORT algorithm, a Kalman-filter-based tracker (DeepSORT is a common appearance-aware alternative).

class FishTracker:
    def __init__(self):
        self.tracker = Sort()

    def track(self, detections):
        tracked_objects = self.tracker.update(detections)
        return tracked_objects
  • __init__: creates the tracker object (here, the SORT algorithm).
  • track: takes the positions detected by YOLOv8 and updates the tracker, returning the tracking result for each fish.
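The data format matters here. A small sketch of the arrays SORT consumes and produces (the coordinate values below are made up for illustration):

import numpy as np
from sort import Sort

tracker = Sort()
# Input: (N, 5) detections of [x1, y1, x2, y2, confidence]
detections = np.array([[100., 120., 160., 180., 0.91],
                       [300., 310., 350., 360., 0.87]])
tracks = tracker.update(detections)
print(tracks)  # Output: (M, 5) rows of [x1, y1, x2, y2, track_id]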

1.3 Feature-Extraction Module (FeatureExtractor)

This module extracts higher-level features from the tracked zebrafish trajectories, such as position, size, and (with the extension sketched below) speed.

class FeatureExtractor:
    def extract(self, tracks):
        features = []
        for track in tracks:
            # SORT returns rows of [x1, y1, x2, y2, track_id]
            x_min, y_min, x_max, y_max, track_id = track
            center_x = (x_min + x_max) / 2
            center_y = (y_min + y_max) / 2
            width = x_max - x_min
            height = y_max - y_min
            
            # Box diagonal as a simple size feature
            diagonal = np.sqrt(width ** 2 + height ** 2)
            features.append([track_id, center_x, center_y, width, height, diagonal])
        
        return np.array(features)
  • extract: from the tracker output, extracts per-fish features: position (box center), size (width and height), and the box diagonal as a size proxy. Note the diagonal is not a true speed; speed needs positions from consecutive frames, as sketched below.
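A single frame cannot yield speed, so here is a sketch of how per-track velocity could be added by remembering each track's previous center. The VelocityExtractor class, its prev_centers cache, and the fps value are assumptions, not part of the original code:

import numpy as np

class VelocityExtractor:
    # Sketch: derive per-track speed from centers in consecutive frames
    def __init__(self, fps=30.0):
        self.fps = fps            # assumed frame rate of the video
        self.prev_centers = {}    # track_id -> (x, y) from the previous frame

    def speed(self, track_id, center_x, center_y):
        prev = self.prev_centers.get(track_id)
        self.prev_centers[track_id] = (center_x, center_y)
        if prev is None:
            return 0.0            # first sighting: no displacement yet
        dx = center_x - prev[0]
        dy = center_y - prev[1]
        # pixels per second = displacement per frame * frames per second
        return float(np.hypot(dx, dy) * self.fps)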

1.4 Behavior-Classification Module (BehaviorClassifier)

This module uses a previously trained LSTM model to classify zebrafish behavior from the extracted features.

class BehaviorClassifier:
    def __init__(self, model_path):
        self.lstm_model = load_model(model_path)
    
    def classify(self, features):
        predictions = self.lstm_model.predict(features[:, np.newaxis, :], verbose=0)
        return np.argmax(predictions, axis=1)
  • __init__: loads the LSTM model used for behavior classification.
  • classify: classifies behavior from the extracted features and returns the predicted behavior labels.
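Feeding single-frame features as length-1 sequences wastes the LSTM's temporal capacity. Here is a sketch of buffering a sliding window of features per track before classifying; the SequenceBuffer class is an assumption, and SEQ_LEN = 16 is a placeholder that must match whatever window the LSTM was actually trained on:

from collections import defaultdict, deque
import numpy as np

SEQ_LEN = 16  # assumed window length; must match the LSTM's training setup

class SequenceBuffer:
    # Sketch: accumulate per-track feature windows for the LSTM
    def __init__(self):
        self.buffers = defaultdict(lambda: deque(maxlen=SEQ_LEN))

    def push(self, track_id, feature_vector):
        self.buffers[track_id].append(feature_vector)

    def ready(self, track_id):
        return len(self.buffers[track_id]) == SEQ_LEN

    def window(self, track_id):
        # shape (1, SEQ_LEN, n_features), ready for lstm_model.predict
        return np.array(self.buffers[track_id])[np.newaxis, :, :]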

1.5 Result-Saving Module (ResultSaver)

This module stores the analysis results in a SQLite database for later querying and analysis.

class ResultSaver:
    def __init__(self, db_path='zebrafish_results.db'):
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        self.cursor = self.conn.cursor()
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS results 
                               (track_id INTEGER, behavior TEXT, timestamp TEXT)''')
        self.conn.commit()

    def save(self, track_id, behavior, timestamp):
        with self.lock:
            self.cursor.execute("INSERT INTO results (track_id, behavior, timestamp) VALUES (?, ?, ?)", 
                                (int(track_id), str(behavior), str(timestamp)))
            self.conn.commit()

    def close(self):
        self.conn.close()
  • __init__: connects to the SQLite database (check_same_thread=False so the worker threads can share the connection) and creates the results table.
  • save: stores each fish's predicted behavior and its timestamp in the database, with a lock serializing the writes.
  • close: closes the database connection.
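Once results accumulate, they can be summarized directly in SQL. For example, counting how often each behavior label was observed per track (a usage sketch against the results table defined above):

import sqlite3

conn = sqlite3.connect('zebrafish_results.db')
rows = conn.execute(
    """SELECT track_id, behavior, COUNT(*) AS n
       FROM results
       GROUP BY track_id, behavior
       ORDER BY track_id, n DESC"""
).fetchall()
for track_id, behavior, n in rows:
    print(f'track {track_id}: behavior {behavior} observed {n} times')
conn.close()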

1.6 Main Analysis Module (ZebrafishAnalyzer)

This module ties all the other modules together and processes each video frame, carrying out the full zebrafish behavior analysis pipeline.

class ZebrafishAnalyzer:
    def __init__(self):
        self.detector = YOLOv8Detector()
        self.tracker = FishTracker()
        self.extractor = FeatureExtractor()
        self.classifier = BehaviorClassifier('/path_to_your_model/lstm_model.h5')
        self.saver = ResultSaver()

    def analyze_frame(self, frame, timestamp):
        detections = self.detector.detect(frame)
        tracks = self.tracker.track(detections)
        features = self.extractor.extract(tracks)

        if len(features) > 0:
            behaviors = self.classifier.classify(features)
            for i, track in enumerate(tracks):
                x_min, y_min, x_max, y_max, track_id = track
                behavior = behaviors[i]
                self.saver.save(track_id, behavior, timestamp)
                logging.info(f'Track ID: {track_id}, Behavior: {behavior}, Timestamp: {timestamp}')
                cv2.putText(frame, f'Behavior: {behavior}', (int(x_min), int(y_min) - 10), 
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                cv2.rectangle(frame, (int(x_min), int(y_min)), 
                              (int(x_max), int(y_max)), (255, 0, 0), 2)

        return frame

    def close(self):
        self.saver.close()
  • __init__: initializes each of the sub-modules.
  • analyze_frame: runs the full pipeline on a single frame: detection, tracking, feature extraction, behavior classification, and result saving, and also annotates the behavior on the image.
  • close: closes the saving module once analysis is finished.

2. Parallel Processing

This part uses Python's ThreadPoolExecutor to process multiple frames concurrently, improving throughput.

def process_video(video_path):
    analyzer = ZebrafishAnalyzer()
    cap = cv2.VideoCapture(video_path)

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            
            timestamp = cap.get(cv2.CAP_PROP_POS_MSEC)
            futures.append(executor.submit(analyzer.analyze_frame, frame, timestamp))
            
            if len(futures) > 10:
                for future in as_completed(futures):
                    frame = future.result()
                    cv2.imshow('Zebrafish Behavior Analysis', frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        cap.release()
                        cv2.destroyAllWindows()
                        analyzer.close()
                        return
                futures.clear()
        
        for future in as_completed(futures):
            frame = future.result()
            cv2.imshow('Zebrafish Behavior Analysis', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    cv2.destroyAllWindows()
    analyzer.close()
  • process_video: processes the whole video, using multiple threads to analyze frames in parallel; once a batch of results has accumulated, the annotated frames are shown in a window. Because as_completed yields frames in completion order, the display can run out of capture order; an order-preserving variant is sketched below.
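One option for keeping the display in capture order is to drain each batch in submission order instead of completion order. The show_in_order helper below is a sketch, not part of the original code:

import cv2

def show_in_order(futures, window='Zebrafish Behavior Analysis'):
    # Sketch: display a batch of analyzed frames in submission order.
    # Iterating the futures list directly (instead of as_completed) keeps
    # capture order; future.result() blocks until that frame is ready.
    for future in futures:
        frame = future.result()
        cv2.imshow(window, frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return False          # caller should stop processing
    futures.clear()
    return True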

3. Higher-Level Feature Extraction

Beyond the basic spatial features, the feature set includes the box diagonal as a size proxy, and dynamic features such as the per-frame velocity sketched in section 1.3 can be added on top; such features help make the behavior classification more accurate.

4. Result Storage

Results are stored in a SQLite database, which makes them easy to query and to use for follow-up statistical analysis.

5. Logging

Logging helps track the program's run-time state and results, which is especially useful when processing large batches of data.

Altogether, this program is a highly modular, feature-rich zebrafish behavior analysis system, suitable for research or other advanced applications.


上一篇
day 28 Lstm結合yolo v8 多隻斑馬魚行為分析效能評估
系列文
LSTM結合Yolo v8對於多隻斑馬魚行為分析29
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言